{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch plot: trajectory `s_` (blue) over the reference spline (red).\n",
    "# NOTE(review): `s_` is not defined in any cell of this notebook -- TODO confirm its origin.\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "x_ = s_[:, 0].data.numpy()\n",
    "y_ = s_[:, 1].data.numpy()\n",
    "\n",
    "plt.scatter(x_, y_, c='b', s=20)\n",
    "plt.scatter(cx, cy, c='r', s=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([11.6842, -5.5343, -0.3132,  0.5882, -0.2611,  0.0123],\n",
       "       grad_fn=<CatBackward>)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One step of a vehicle_model built with first argument 1.0 (presumably the wheelbase WB,\n",
    "# cf. `vehicle_model(WB=2.5)` in the main cell -- TODO confirm), starting from the first\n",
    "# state of the MPC rollout `path` with the first optimized control row.\n",
    "rc = vehicle_model(1.0)\n",
    "rc.forward(path[0], controls_dt[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1.1684e+01, -5.5343e+00, -1.5506e+01,  5.8816e-01, -2.6109e-01,\n",
       "         1.2283e-02], grad_fn=<CatBackward>)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same single step as `rc`, but with first constructor argument 1/200 for comparison.\n",
    "rc_ = vehicle_model(1.0 / 200)\n",
    "rc_.forward(path[0], controls_dt[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0,  3,  6,  9, 12, 15, 18, 21, 24, 27, 30])\n",
      "T=\n",
      "0\n",
      "0.51528883\n",
      "0.000494946\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "1\n",
      "0.3190604\n",
      "0.000483717\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "2\n",
      "0.29562992\n",
      "0.0010066641\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "3\n",
      "0.43493512\n",
      "0.0042665284\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "4\n",
      "0.8486596\n",
      "0.011120641\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "5\n",
      "1.4111841\n",
      "0.018706586\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "6\n",
      "1.7719783\n",
      "0.026163865\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "7\n",
      "1.6556798\n",
      "0.031296644\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "8\n",
      "1.5561366\n",
      "0.02818155\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "9\n",
      "1.1107941\n",
      "0.02765983\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "10\n",
      "9.285178\n",
      "0.045080848\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "11\n",
      "5.9877396\n",
      "0.08164187\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "12\n",
      "2.6366045\n",
      "0.13417928\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "13\n",
      "0.56148475\n",
      "0.13050494\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "14\n",
      "0.5949676\n",
      "0.11052068\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "15\n",
      "1.0449328\n",
      "0.05780882\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "16\n",
      "0.6465386\n",
      "0.027876774\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "17\n",
      "0.24150792\n",
      "0.00922889\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "18\n",
      "0.2030375\n",
      "0.0012901748\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "19\n",
      "0.24905509\n",
      "0.005600761\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7f0f96af4630>"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD4CAYAAADmWv3KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAZzElEQVR4nO3df5DcdX3H8eebH4GQEKneoZUkHG1RcfwVWB0tLVVJLCoEoXZqUEdXxrRa8BfcyY+pilZr76zKqLVmDIeOcNYBVGCoCra0Y6dqLoC/CFRGg4RCubP1ByEQUt/947ubbDa7e7u3393v5/P5vh4zmb293dv9ZH+8v+99fT7f75q7IyIi6Tmo6AGIiMhgqMCLiCRKBV5EJFEq8CIiiVKBFxFJ1CFF3OnIyIiPjY0VcdciItHaunXrvLuPdnv9Qgr82NgYs7OzRdy1iEi0zOzeXq6viEZEJFEq8CIiiVKBFxFJlAq8iEiiVOBFRBKlAi8iQzU3B1u2ZKe9XCa9U4EXkaGZmYFjj4V167LTmZnuLpPFsSIOF1ypVFzr4EXKZW4uK9y7du373dKlcG9tZXe7y0a73q0nfWa21d0r3V5fHbyI5O6AqGV+nkc/MMUJB93NhUzxJOZ5EvNc8JspHvjXuw+4DODQQ2H79sL+C0koZE9WEUnXzAycey4sWQJHPjbP9WdNs+ZpO1n1icv464Nu5eXctPe6H3hsgsc+cyuH3XLT3suWsZOdLOOLu6uMjY3sve7cXFbwx8bU1XdLBV5EcjM3lxX3I3bNU901zRHsZM3MZey88L0sm5xkz0HrueSiF3PN4VX27IFzzoI1710P179472VHHrST9+2e4JyzYNSqMDXNtSuqvP6dIyxZArt3w+bNsGFD0f/b8CmDF5HcbNmSTZK++ZdTTDHB+3gv/3fYMs6+ocqadVk33qkTn5uDHXfMc/y3pll+fhWmp2FigksOmeRv9ozvvV5Z8/leM3h18CKSj/l5TrhxmiMfqzJNFYBpqjxy0Ahve96+q42Oti/Mo6Mwum4E1tWKebXKfffB169Yz4V7ppimys8Z2ZvPl63A90oFXkTyMT3N8vdPcP0GOPkr42w6dJzHH8/ilEUX4pERDv+rcf7401N8iIm2+by0pgIvIotSj1qOO3KekRumYf16ANZUq9x7eX4ToqOjcNInq1xyHvvn86PjC/9xyanAi0jPGlfKvHXnNB/aM5FdMJ4V3VHyjU/+5M9HOOXscXbcMc/D31rGmnPWw9QUVKswok6+HRV4EelJ80qZa1gPh8C7zqgyyFK7Xz4/NQUTEzz8MGw7fVxLJ9vIZUcnMzvKzK4xs7vMbJuZvSiP2xWR8GzfnnXuVaaZYoIzuZ5PHTHOT389xE66WuX2DZOcMFnlz06d52NPneLaz8wP7/4jkVcHfznwNXd/tZktAY7I6XZFJDBjY9la9MaVMo8/nv1+WOZ8hJO/Ms6uR+HCR7MJ2EvOg1POHlcn36DvDt7MngCcAmwGcPfd7v6Lfm9XRMI0avP8+6umWHo4bFoxziNLR/pbKbMI9U8RkG1gxplk5vCqDm3QJI+I5jhgDpg2s9vN7LNmtqz5Sma20cxmzWx2TscCFYnX9DRrZibYNjHNLbdkOxwNe6/S+qcIgJ8zwkcYZ88eOOHGKZhXVFOXR4E/BDgR+LS7rwF2Ahc1X8ndN7l7xd0ro/oMJRKf+flscnP9epicZPn5VZ7//GImN0dHs/X1S5fCihXZ6fVnZevwmZ4e/oAClUcGvwPY4e7fqZ2/hhYFXkQiVztsALB3OWSRNmyAtWsb1ttbFdaQbYC0hBLIocC7+4Nmdp+ZPd3d7wZOBe7sf2giEpRqdf/TAOx/2IORbMNTW0IJBLEhKlJeq2jOB66qraD5CRDOK0BE+jM/n3Xv1WoUBXP+jCq77oOlA16XH4Nc1sG7+x21fP057v4qd//fPG5X
RAJQj2YiyLZnZmD1iSM8+/PjrFkDt59T7klX7ckqIp0FGM20Ut/Ddteu7N+byVb7PHw8LL8s/E8eg6ACLyIHaoxlRkaiiGbqa+Pr3+s6TZXDD4Oz/6DKmkJHVhx9J6uIHCiiWKaucW08ZOvj/+6gcVatIpt4LWFUow5eRA4USSzTqL42/txzsy/srh+LfuSGsJZ3DpMKvIjsE9mKmWYHrI0fBebj21jlRRGNiOwTYTTTbHSU/fewbZxDKFlUow5eRPaJMJrpWmB74g6DCryI7BPJiplezc3BjudVOf49sDzFjVcbimhSUD8I1Px8+59FOkn4tTIzA8ceCy/50xGOnhpn5uby7N+qDj5WjZNhjR89ofXP9evV1zU3r3OWcks0vmje+Qmy8+vW1L4oPPHXvwp86BoLMbQu6q1y0+afm9/AzX+vYl9uiWbvzTs/QbaEctffT8Mn0tugNVOBD1277rzxDdmcm7b6ufkN3HjaXPzV3ZdH5MsiF9K88xNk6+OXvrUKq0hug3YAdx/6v5NOOsmlg7k598nJ7LTdz4O6P/fsZ8hOW10u6Wh+rhN09dXuS5e6r1iRnV59dcOFkb22gVnvodaqwIei8YVW9JtuoYIv6YiswC3WQw+5f/e72el+Intt91rgFdGEYqFMfZiaI5/G8Si+iV+EBxLr1/5fDNKg6PfagKnAF63+Zlu/Pjsf4puueU9A5fVxS3TFzKKE9l7LmQp8EdotcYzhhdbc8cQ2fkm+a12URBsVFfgihBTH9KpTfAPJvlGSknjXuiiJNioq8MMUQxzTq+bxJ/pGSYI2vi2lfBgDFfhhKkPxU0cfrjK8/no0M5Pt2bpkyQi7d4+z+RnZIYdToQI/aI0FLrY4ZjHU0YerDK+/HrQ6jMHEm+Y5465plp+fRkOig40NWuPxtevFL4EXTteqVZic3L+jT/SgVsGqP+ZQvtdfB/XDGDR6o0+z/P1xHw+/kTr4QSt716SOvnh6zFtqdRiDK63KuxPK4lXgByXxY3wsmjL64St7k9FGq+9wndw8wvIN6bxfc4tozOxgM7vdzG7M6zajlsBXnw1Ec0ylx2kwGqOwMkaDXdqwAe69F265JTvdsIGkYsQ8O/i3A9uAFTneZnxaLYWU9lp1l+rq+6dYpmsHHMYgocculwJvZiuBVwIfBN6Vx21GK6EXx1C02g9Aj2H/FMssXkKPXV4d/MeBCeDIdlcws43ARoDVq1fndLcBSujFURg9hounuZ/+xb7zYYO+M3gzOx14yN23drqeu29y94q7V0ZbHtYtYso789X8GCaUiQ6c5jTyk8DrLo8O/mRgvZm9AjgcWGFmX3D31+Vw23FQpDBYeny7p08/+Ungddd3gXf3i4GLAczsxcCFpSruoDfVoGlpZWclPL77UNReb/NnVPnplmzdfGzhg/Zk7Yf2EBwOLa3sTI/HYIyMMLNynNUnjrBuHRx7bHbsmpjkuqOTu98K3JrnbQYtgY9wUVJHvz99ghyI+rFqjtg1T3XXNNNUOffcEdaujaeT156s/dAbqxg6/IFimSGoH6umumuaKbLX16ZDx9m+XQU+bVqKFpYydvRl3KgNWf1YNdNkr6tpqjz+ePb7WCiDXwxlnmFZKKNPYLkbsP//o/konZK7+rFqHlk6wqYV4zyydITNm+Pp3kEd/OIomglbqt8b2/z/iPn/EokNG2Dt2iyuOe7IeUZumIb5eD4ZqsD3QtFMHFL93lg1FoXYe6yaqfgaBRX4XqTSCZZNzJOymkwNR4QbWBX4XkT4BEsLoXf0jeOJaWOUugg3sCrw3VA0k5aFOvqiC37jeNRUSB9U4LuhLiptC03KDqPgt/ty9gi7xqQVvfHvkQp8N9RFpW2hSdlBFfxOUYyKepgia/ZU4DtRNFNO/Rb8xvP167e6TFFMfCJ7nlTgO4lsay0D0mvBbzwP7S9TFBOfyJ4nFfhOIttay5AsVPBbvW5aXRZZsZAGkWTx5u5Dv9NKpeKzs7ND
v18RkVxMTWWfxCYnh7qRNrOt7l7p9vrq4NuJZAstIgWI5NO9DjbWjg4oJiLtRPLdy+rg24lkCy0i0o46+HYi2UKLSEEiOAy1CnyzCJ40EQlABDGuIppmWvsuIt2IIMZVgW8WwZMmIgGIYD8GRTR19WgGlL2LSBJU4OsiyNNEJEABz9v1HdGY2Srg88CTAQc2ufvl/d7u0CmaEZHFCHjeLo8Mfg9wgbvfZmZHAlvN7GZ3vzOH2x6eCPI0EQlQwM1h3xGNuz/g7rfVfv41sA04pt/bHZqAP16JSAQC3mcm1wzezMaANcB3Wly20cxmzWx2bm4uz7vtj7J3EUlUbgXezJYD1wLvcPdfNV/u7pvcveLuldHR0bzutn/VanZEuAA/XolIJAJNAnIp8GZ2KFlxv8rdr8vjNocm4I9XIhKJQJOAPFbRGLAZ2ObuH+1/SEOkQwKLSB4CnWjNo4M/GXg98FIzu6P27xU53O7gBbrVFZHIBJoE9N3Bu/u3AMthLMMX6FZXRCIVWCpQ7j1ZA93qikikAksFynmwscC2siKSiMBSgXJ28IFtZUUkEYGlAuXs4APbyoqIDEI5O/jAtrIikpCAdnoqV4EP6IEXkUQFFAGXK6IJ+LCeIpKIgCLgchX4gB54EUlUQIceL1eBD+iBFxEZtHJl8CIiwxDIfF95CnwgD7iIlEAgE63liWg0wSoiwxLIfF95CnwgD7iIlEAg833lKfCBPOAiIsOSfgav7F1EihBA7Um/wAcy2SEiJRNA7Uk/olH2LiJFCKD2mLsP/U4rlYrPzs4O/X5FRGJmZlvdvdLt9dOOaALIwEREipJ2gQ8gAxOREiu4yUw7gw8gAxOREit4B8u0C7zWvotIkQpuMtMu8CIiRSq4yUwzg9fkqohIPgXezE4zs7vN7B4zuyiP2+yLJldFJBQFNpx9RzRmdjDwKWAdsAPYYmbXu/ud/d72omlyVURCUeBEax4Z/AuAe9z9JwBm9kXgTKC4At9t7jU/nz341Wr2NyIieSuw4cwjojkGuK/h/I7a7/ZjZhvNbNbMZufm5nK42zZ6+TjUHOUouxeRvNUbzgKayKFNsrr7JnevuHtldHR0cHfUS/5ercLk5L4ta+PfqtiLSB5izuCB+4FVDedX1n5XjF4+DjVHOY1/q2+AEpE8RJ7BbwGON7PjyAr7a4Bzcrjdxeln3Wnj3zZvKJTXi8hixJzBu/se4Dzg68A24Evu/qN+b7dwzbmZ4hsRWYwCM/hc9mR195uAm/K4rUUbdIfdKb5Rdy8iAUpnT9ZB79zUuBXW5KyI9KKgupDOsWiGmXNpclZEelFQXUinwBd5UB9NzopIJwVNtKYR0YQUi3SanIWwxioiw1HQRGsaHXzIsUjzlrtxrPVIR929iAxAGgU+5IOL9ZLXK84RSVcB7+80CnxM39zUKa8P+ZOIiPSngPd3GgU+Vp26e3XzImkpIGmIe5I1tQnLxokYTc6KpKWAida4O/iUI42F4ht1+CKygLgLfMiTq/3qFN9A2hs3kRRpkrVHMU2u9kt5vUjcCmjK4s3gy55JK68XiUvzMayGIN4OXhHFPsrrRcJXQOIQb4FPOX/vlfJ6kfApg+9BmfL3XimvFwmPdnSSgWgs+FNTim9EilBA6hBngVdRWjzFNyLFUAbfJRWlxVsor9fGUyQZcS6TLGC5UbJ0/HqR4SjgvRRnB68J1sFRhCMyGJpklcIpwhEZDB1NcgGKC4avU4Sj50OkewUcTbKvAm9mU2Z2l5l938y+bGZH5TWwlprzYRm+xvkPPR8i3Yswg78ZuNjd95jZ3wIXA+/uf1htaO/V4nX6RirFNyLtxZbBu/s3Gs5+G3h1f8NZgCZXw9L8fOgYOCLtRb6j05uAf8zx9iQ2WoEj0l6IOzqZ2S3AU1pcdKm7f7V2nUuBPcBVHW5nI7ARYPXq1YsarAROK3BEgrLgJKu7r3X3Z7X4Vy/ubwROB17r7t7hdja5e8XdK6Ojo7n9
ByRg2olKZJ/YJlnN7DRgAvgjd38knyFJshThSJnFNskKfBI4DLjZzAC+7e5/0feoJE2KcKTMYptkdfffy2sgUkILrcIRSUkBk6xx7ckqaWt1EDnl9JKKAl7LKvASjla7cmtiVlJRwJ7fOtiYhE0Ts5KKAjJ467CycWAqlYrPzs4O/X4lAc0TsZqYlRIxs63uXun2+opoJC4Lra0XCVVs6+BFCqellhKLCNfBixRLSy0lFrGtgxcJTqs3kbp6CYHWwYv0qZulliLDVtDyXnXwkj7l9FK0gqJDFXhJn3J6KVpB30anAi/lo45ehq2gb6NTBi/lo7X0MmzK4EUKopU3MmjK4EUK0urjs3J6yZMyeJGAKKeXPCmDFwmIcnrJS4GHuFYHL9IN5fSyWAXGfSrwIt1QTi+LVVD+DirwIounrl66UVD+DsrgRRZPx72RhRT8FZPq4EXypNU30qjgGE8FXiRPOu6NNCowfwcVeJHBUk5fXgE8z7lk8GZ2gZm5menVKtJIOX15BfA8993Bm9kq4GXAz/ofjkgJqKsvh4LjGcing/8YMAF4Drclkj519ekLZIPdVwdvZmcC97v798xsoetuBDYCrF69up+7FUmPVt+kJZDJ9QULvJndAjylxUWXApeQxTMLcvdNwCaASqWibl+kkVbfpCWAeAa6KPDuvrbV783s2cBxQL17XwncZmYvcPcHcx2lSNkop49T43MUwIZ50RGNu/8AOLp+3sy2AxV3L2aXLZGU6Ng3cQrsOdI6eJFYqKsPXyDRTF1uBd7dx/K6LRFpQV192ALc2KqDF4mZVt+EI8CNrQq8SMy6WX2joj8cgcUzoAIvkpZWRSbAzjIpga2caaQCL5KSVjm9JmcHK+ANqAq8SOo0OTsY9Y3k+vXZ+YCimToVeJEy0uRs/yLYSKrAi5SRJmf7F+CkajMVeBHR5GwvAp5UbaYCLyLdTc6qo89EtOFTgReR1hTj7C+CSdVmKvAi0p2yxjj1wr5zJ1x2Wfa7SP6vuXwnq4iUQKtvoqpWYXJy/xhnaio7TUXjRqzx/xoBdfAisngLxTixRjiN42785BLT/wF18CKSp+aOvvm7ZkPu8BvH1jjuVp9cIqEOXkTy09zRN+f23UzUDrvrb5WxR7DGvRvq4EVkcJq73+YOHw7s8hvPN3f8i/0E0Ol2WmXsEXftjdTBi8jwdLPevvG0ueNfKONvPA/7fu50OxFn7AtRgReRYjUX/cbznYo/dC7c0LqIN5+22ugkwtx96HdaqVR8dnZ26PcrIonptoNPpDM3s63uXun6+irwIiJx6LXAa5JVRCRRKvAiIolSgRcRSZQKvIhIovou8GZ2vpndZWY/MrPJPAYlIiL962sdvJm9BDgTeK67P2ZmR+czLBER6Ve/HfxbgA+7+2MA7v5Q/0MSEZE89Lsn69OAPzSzDwKPAhe6+5ZWVzSzjcDG2tmHzezuPu87LyNAgIe2a0ljzV8s4wSNdRBiGSdkYz22lz9YsMCb2S3AU1pcdGnt758IvBB4PvAlM/sdb7H3lLtvAjb1MrhhMLPZXnYcKJLGmr9Yxgka6yDEMk7YO9axXv5mwQLv7ms73OFbgOtqBf27ZvYbsq3MXC+DEBGR/PWbwX8FeAmAmT0NWEI8H3dERJLWbwZ/BXCFmf0Q2A28oVU8E7jgYqMONNb8xTJO0FgHIZZxwiLGWsjBxkREZPC0J6uISKJU4EVEElXqAm9mp5nZ3WZ2j5ldVPR42jGzVWb2L2Z2Z+2QEG8vekydmNnBZna7md1Y9Fg6MbOjzOya2qE2tpnZi4oeUztm9s7ac/9DM5sxs8OLHlOdmV1hZg/V5uLqv3uimd1sZj+unf5WkWOsjanVOKdqz//3zezLZnZUkWOsazXWhssuMDM3swW/xaS0Bd7MDgY+BbwceCawwcyeWeyo2toDXODuzyTb5+AvAx4rwNuBbUUPoguXA19z92cAzyXQMZvZMcDbgIq7Pws4GHhNsaPa
z5XAaU2/uwj4prsfD3yzdr5oV3LgOG8GnuXuzwH+E7h42INq40oOHCtmtgp4GfCzbm6ktAUeeAFwj7v/xN13A18kO65OcNz9AXe/rfbzr8kK0THFjqo1M1sJvBL4bNFj6cTMngCcAmwGcPfd7v6LYkfV0SHAUjM7BDgC+K+Cx7OXu/8b8D9Nvz4T+Fzt588BrxrqoFpoNU53/4a776md/TawcugDa6HNYwrwMWAC6Gp1TJkL/DHAfQ3ndxBo0WxkZmPAGuA7xY6krY+TvQB/U/RAFnAc2Q5507U46bNmtqzoQbXi7vcDHyHr2h4Afunu3yh2VAt6srs/UPv5QeDJRQ6mS28C/qnoQbRjZmcC97v797r9mzIX+OiY2XLgWuAd7v6rosfTzMxOBx5y961Fj6ULhwAnAp929zXATsKIEQ5Qy6/PJNsoPRVYZmavK3ZU3avtGxP0emwzu5QsCr2q6LG0YmZHAJcA7+nl78pc4O8HVjWcX1n7XZDM7FCy4n6Vu19X9HjaOBlYb2bbySKvl5rZF4odUls7gB3uXv8kdA1ZwQ/RWuCn7j7n7o8D1wG/X/CYFvLfZvbbALXTYI80a2ZvBE4HXhvwjpq/S7aB/17t/bUSuM3MWh0nbK8yF/gtwPFmdpyZLSGbtLq+4DG1ZGZGlhVvc/ePFj2edtz9YndfWTsg0muAf3b3IDtNd38QuM/Mnl771anAnQUOqZOfAS80syNqr4VTCXRCuMH1wBtqP78B+GqBY2nLzE4jixTXu/sjRY+nHXf/gbsf7e5jtffXDuDE2uu4rdIW+NrEynnA18neLF9y9x8VO6q2TgZeT9YR31H794qiB5WA84GrzOz7wPOADxU8npZqnzKuAW4DfkD2vg1mF3szmwH+A3i6me0ws3OBDwPrzOzHZJ9APlzkGKHtOD8JHAncXHtf/UOhg6xpM9bebyfcTyQiItKP0nbwIiKpU4EXEUmUCryISKJU4EVEEqUCLyKSKBV4EZFEqcCLiCTq/wGIx2apUfUDkgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from evals import *\n",
    "from optimization import *\n",
    "from gauss_update import *\n",
    "from kinetic_model import *\n",
    "import cubic_spline_planner\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# NOTE(review): this cell uses `MPC` and `evaluate_`; this notebook defines them in\n",
    "# cells further down, so run those first on a fresh kernel.\n",
    "\n",
    "# Reference course: cubic spline through the anchor points, sampled with ds=0.2.\n",
    "ax = [0.0, 6.0, 12.5, 10.0, 7.5, 3.0, -1.0]\n",
    "ay = [0.0, -3.0, -5.0, 6.5, 3.0, 5.0, -2.0]\n",
    "goal = [ax[-1], ay[-1]]\n",
    "\n",
    "cx, cy, cyaw, ck, s = cubic_spline_planner.calc_spline_course(\n",
    "        ax, ay, ds=0.2)\n",
    "\n",
    "cx = torch.Tensor(cx).view(-1,1)\n",
    "cy = torch.Tensor(cy).view(-1,1)\n",
    "\n",
    "# Differences between consecutive spline samples.\n",
    "gx = cx[1:]-cx[:-1]\n",
    "gy = cy[1:]-cy[:-1]\n",
    "\n",
    "# The MPC plans with `process_model`; the plant it drives is `real_car_model`\n",
    "# (different WB and noise=True).\n",
    "process_model = vehicle_model(WB=2.5)\n",
    "real_car_model = vehicle_model(WB=1.5,noise=True)\n",
    "\n",
    "# One row per waypoint: [x, y, dx, dy].\n",
    "waypoints = torch.cat([cx[:-1],cy[:-1],gx,gy],dim=1)\n",
    "\n",
    "start=40\n",
    "\n",
    "# Initial state built from the waypoint at `start`: heading along the segment,\n",
    "# speed 3x the segment length, zero steering and acceleration.\n",
    "x0,y0,gx0,gy0 = waypoints[start]\n",
    "v0 = torch.sqrt(gx0.pow(2)+gy0.pow(2))*3\n",
    "yaw0 = torch.atan2(gy0,gx0)\n",
    "delta0=torch.zeros(1)\n",
    "a0=torch.zeros(1)\n",
    "state0 = torch.cat(  [x0.view(1),y0.view(1),yaw0.view(1),v0.view(1),delta0,a0])\n",
    "control0 = torch.Tensor([0.0001,0.0001,0.001])\n",
    "\n",
    "len_horizon = 10\n",
    "EMAX=1000\n",
    "kappa=0.01\n",
    "TMAX = 20\n",
    "ql=1.\n",
    "qc=1.\n",
    "step = 0\n",
    "_controls = torch.nn.Parameter(torch.zeros(len_horizon,2))\n",
    "\n",
    "dt = torch.ones(len_horizon,1)\n",
    "# Waypoint-index offsets for the horizon: 0, 3, ..., 3*len_horizon (int64).\n",
    "vs = torch.arange(len_horizon+1)*3\n",
    "\n",
    "print(vs)\n",
    "real_path = []\n",
    "for T in range(TMAX):\n",
    "    print(\"T=\")\n",
    "    print(T)\n",
    "\n",
    "    if T>0:\n",
    "        start = start_\n",
    "        _controls = controls_.data.clone()\n",
    "    mpc = MPC(process_model.forward,evaluate_,len_horizon,waypoints)\n",
    "\n",
    "    controls_dt,path = mpc.run(state0,_controls,vs,dt,start)\n",
    "    controls = controls_dt[:,:-1]\n",
    "    print(state0.size(),controls_dt[0].size())\n",
    "    # Apply only the first optimized control to the plant (receding horizon).\n",
    "    state_real = real_car_model.forward(state0,controls_dt[0])\n",
    "    # Detach so the history does not keep every iteration's autograd graph alive.\n",
    "    real_path.append(state_real.detach())\n",
    "    # Warm start for the next iteration: the *optimized* controls shifted by step+1 and\n",
    "    # zero-padded at the end. (Shifting the initial guess `_controls` instead, as before,\n",
    "    # meant the optimized solution was never reused and the warm start stayed all zeros.)\n",
    "    optimized = controls.detach()\n",
    "    controls_ = torch.zeros_like(optimized)\n",
    "    controls_[:-(step+1)] = optimized[step+1:]\n",
    "    start_ = search_(state_real,waypoints,start)\n",
    "\n",
    "    state0 = state_real.data.clone()\n",
    "\n",
    "# Last MPC predicted rollout (blue) over the reference spline (red).\n",
    "x_ = path[:,0].data.numpy()\n",
    "y_ = path[:,1].data.numpy()\n",
    "\n",
    "plt.scatter(x_,y_,c='b',s=20)\n",
    "plt.scatter(cx,cy,c='r',s=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([20, 6])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7f0f968ed0b8>"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD4CAYAAADmWv3KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAbSklEQVR4nO3df4xld3nf8fdTe7eMMQtNZ4DW6/W4EaFFNNT2GNGiUgK7kZs466TJH542ER6Pum0ElLQwWxy3WE7UiMygJEhBSVbY41QxQyNDCo0owUtDq0oBdtY1EOy4oWQX1oH4Tto6dD2O1+3TP849njNn7+9z7v1+z/d8XtLo/pxzv3Pn3uc+9/l+z3PM3RERkfT8hdADEBGR6VCAFxFJlAK8iEiiFOBFRBKlAC8ikqgrQzzo/Py8Ly4uhnhoEZHGOnv27I67L4x6/yABfnFxke3t7RAPLSLSWGZ2fpz7q0QjIpIoBXgRkUQpwIuIJEoBXkQkUQrwIiKJUoAXEUmUAryISKIU4EVEEqUALyLTt7MDGxvwxBP7T3d2Lr9tZyf0aJMRZE9WEWmJnR3Y3ISLF+Hee+Fzn4NPfWrvNHfy5N51Fy/Ci18MKyswPx9o4GlQgBeR+pUD+z33wPo6HD8Ob37z3unKyt7v5NddvJgFfMhu39xUsJ+QhThk39LSkqsXjUjCNjayIH3PPeNn4/mHQx7cJ91OgszsrLsvjXp/ZfAiUo9iYM4z80kC8vw8rK3t/T7sz+rz22QoBXgRqUeebUMWhOsIxHmw39nJMvjjx7NvBy3P5EelAC8i1eSZ+/Hj2eViXb0ueaDPSz+gTH4ECvAiUk05c5+m/MNDmfxIFOBFZDKzyNzLlMmPpZYAb2YvAz4MvBZw4E53/706ti0ikZpl5l5WnMQtTu4qm9+nrgz+g8Cn3f3HzOwgcFVN2xWR2ITI3MuKK22UzfdVOcCb2UuBNwF3ALj7c8BzVbcrIpEKmbn3orp8X3Vk8NcDHWDTzF4HnAXe5e4Xi3cysxPACYAjR47U8LAiEkSxPBID1eX7qqPZ2JXAjcCvuPsNwEXgveU7ufspd19y96WFhYUaHlZEZipvCgZZAI0tS15ZydohxPLBE4E6AvwF4IK7f6F7+SGygC8iKclLM5uboUfSW7kur66U1Us07v5tM/ummb3a3Z8A3go8Vn1oIhKV2Eoz/cQ2RxBQXato3gk82F1B83Ug8leAiIysuAyxCQGzKR9EM1DLAT/c/dFuff173f2H3f1/1bFdEYlA7KWZMpVqXqA9WUVksKZmxCrVKMCLSA/lvUObGCCb+sFUIx2TVUQu17SyTC8q1SiDF5EeUsp+W1yqUYAXkT1NWzEzipQ+rMakEo2I7EmhNFPW4lKNMngR2ZNyttvCUo0CvIjsaeqKmVGk/OHVh0o0KcibQO3sjHZepKwNr48WlmqUwTdJcQIM9s4Xv3rC8PP5Uep1FBzJtal80aK/VQE+dsVA3C+Q9/rqOex8cVv5thXs26tN5Ys2/a3uPvOfm266yWWATsd9fX3vFPZf7nT2n6/zMcq3Sdr0v24UYNvHiLXK4GNUzq7z0/IEWJWvl8VtlTOa8ldYlXPS1aJyRU+Jv7YV4GNRfKENCurTUH6MYQFf0tGmckUvib+2Lcv6Z2tpacm3t7dn/rhRygP7xYtw773ZIcdie6H1m9xNMONphcSz1rE07Lkws7PuvjTq/bVMMrRiBhHr8STzDH9+/vI9HduwvC41Ke6tOqniaztBKtGEMKgcEzuVb5qv7WWZXhqWyY9KAT6EclBsUmAcVq9P9I2SlJT3Vp1UoomKAvws5cHv+PHscgoZVDlYJPpGSYI+fPtL9FuNAvwstSH4KaOPVxtef5NK9FuNAvy09au3p0oZfbza8PqrIsFkRKtopq24YiHxGfueVlb2rw7SqpvZy59zaN/rbxwJri5SBj9tbc+alNGH
p+d8NAm+V2sL8GZ2BbANPOnut9a13cZK8dBndVCNfvYSDFxTkWAdvs4SzbuAx2vcXrMl+HWvFuUylZ6n6SiWwtpYGqwioTJiLRm8mR0GfhD4N8C/qGObjZXiUshp6pVdKquvTmWZySX03NVVovkl4CTwkpq211wJvThmotfXYj2H1aksM7mEnrvKAd7MbgWecvezZvbmAfc7AZwAOHLkSNWHjVdCL45gVKefnOZ+qkuoFl9HDf6NwHEzOwd8FHiLmf1G+U7ufsrdl9x9aWFhoYaHjYyWotVHdfrJ6bmqTwK1+MoZvLvfBdwF0M3g3+PuP151u42jssL0KKMfnb5B1ieB97TWwddFb6zp0Vr6wcofeHpO6tF9L+/80Ap/dAYWF6FpxYda92R198+1ag28lqKFob1j91NZZjrm59k6vMaRG+c5dgyuuw62tkIPajxqVVCF3lhhqEa/X/kDT2rR6cDqKly1u8M/fnqDq3Z3WF3Nrm8KlWgmobXucWljjV5lmak7dw4OHoSV3U02yEqCpw6sce5cc0o1CvCTUA04Lm2s0bfhbwxscRGeew42yRKHTVa4dCm7vikU4CehCdW4pZrRt631dGALC3DffbC6Os+pA2tcupRdbkr2DmDuPvMHXVpa8u3t7Zk/rrTUxkaW7a6vNzvbTeXvaJhOJyvXXP+SHeb/Q9hEwczOuvvSqPdXBj+OVDLBtmlyRq+sPbiFhW7WvtG8spgC/DhU92ymJtfom3yA9tQ08ANWAX4cDfwHSw+xZ/TK2uPUwNVKWgc/CvWZScuwdfShd5xq+2EepTYK8KNo+440qSvvKBQi4BcfQzsuxSv0h/+YVKIZhb4mp6381bv8/y7Xwesq6RS3o1p7MzRp/gYF+OFiq8/K9FUN+MXL+f173VbcjpKIZmjY/0kBfpiGfWLLFIwb8IuXof9txe00cAKvlbr/p04HzjWgw6QC/DAN+8SWGRgW8Hu9ZnrdpqDeSFtbcPLOHe5gkwdYYf3+eZaXQ4+qN+3JKiIyok4naxv89t0NNjjJGut8aG6N8+dnk8lrT9a6qPYuIiV5h8nN3b0GZAcOEG2HSS2T7EdLI0WkJO8w+afM8wHW+FPmo+4wqQy+H9XeRaRkr8MkHDhA9B0mlcGXaa9VERlgeRnOn4fPPbTDU2sbLB+Ld6cnBfgylWZEZIiFBbjh0U2u/pm4Y4VKNGUqzYjIKBoQK7RMUkSkIcZdJqkSTa5hTYRERIZRgM+p9i4ik4g4Oaxcgzeza4F/C7wCcOCUu3+w6nZnrgH1NBGJUMT9quqYZH0eeLe7P2JmLwHOmtnD7v5YDdueHfUFEZFJRJwcVi7RuPu33P2R7vnvAI8D11Td7sxE/PVKROLV6cCZM9DxeI+6VWsN3swWgRuAL/S47YSZbZvZdqfTqfNhq1HtXUTGtLWVNR07diw7/divxZko1rYO3syuBj4G/JS7/1n5dnc/BZyCbJlkXY9bWcRfr0QkPp1O1qpgdzf7ATj7jk1+9Pn46vC1BHgzO0AW3B9094/Xsc2ZUe1dRMaQd5TMgzvA1otW+MkVuDayRLFyicbMDLgPeNzdf6H6kGZI9XcRGVPeUbLoT/7vPC/61/HV4euowb8R+AngLWb2aPfnB2rY7vSp/i4iY8o7Ss7NwaFD2WmsHSUrl2jc/b8CVsNYZk/1dxGZwPIyHD2alWv2HZc1sgMFtbvZmOrvIjKhhYUeWXtkOz21M8BH9ikrIomIrCrQzl40qr2LyDTMx7XTUzsz+Mg+ZUVEpqGdGXxkn7IikpCIll+3K8BH9MSLSKIiKgG3q0QT2Qy3iCQoohJwuwJ8RE+8iDRTp9Nj/XtRRMuv21WiUe1dRCood5Hc2go9osHaE+BVfxeRCopdJJ9+OjtdXc2uv0wk8aY9AT6iiQ8RaZ68i2TRgQPZ9ZeJJN60pwav+ruIVNCri+SlS9n1l4kk3rQng1f9XUQqGKuLZCTxpj0ZvIhIRX27SEYq/Qw+
kskOEUnDwgLcfPMIwT2C2JN+gI9kskNEWiaC2JN+iSaSyQ4RaZkIYo+5+8wfdGlpybe3t2f+uCIiTWZmZ919adT7p12iiaAGJiISStoBPoIamIi0WOAkM+0afAQ1MBFpscAdbNMO8BF1dROR5hvaSbIscJKZZolGtXcRqdlEnSQD79GaZoBX7V1EajRWJ8mygAlnLSUaM7sF+CBwBfBhd39/HdudmGrvIlKjvJPk7u7edXknyaGlmoB1+MoB3syuAD4EHAMuAGfM7JPu/ljVbU9MtXcRqdFYnSTLAiacdZRoXg98zd2/7u7PAR8Fbqthu9OnWr2IjGCsTpJlAevwdQT4a4BvFi5f6F63j5mdMLNtM9vujFS4mtA4Qbtcq1fAF5E+lpfh/Hk4fTo7XV4OPaLhZjbJ6u6n3H3J3ZcWptljc5wJ1pUVWF/f++pU/F0FexEpGbmTZFHDJ1mfBK4tXD7cvS6Mcepd5Vp98XfLEyM7O9l1KyvBm/iLSIM0eZIVOAO8ysyuJwvstwP/sIbtTqbKBGvxd8sfFIH3SBORhgo4yVo5wLv782b2DuB3yJZJ3u/uX608snHVnWEPyu6n8XgikqaAq/pqWQfv7p8CPlXHtiY27Qy7/E8qPl5e0lGwF5GIpNOLZtZfg1SvF5FRBYoJ6QT4WX8NGrVer+xeRALN4aUT4EMaZzWOiDTa2B0lIdhEaxrNxmJbs17cc6281j62sYrIyCbqKAnB9mZNI8DH3D2y/I/VzlQijVSpo2QgaZRomtQ9UpOzIo1UqaOkJlkraFL3SE3OijRSpY6SmmRtIU3OijRG3lFydTXL3C9dGqOjZKAqg7n7TB8QYGlpybe3t+vZWKpljfLflerfKdIwE62iqYmZnXX3pVHv3/xJ1pgnWKsYNDkLmqAVCWSijpIQ5D3b/BJNkyZYq1C9XqTZApRdmx/gmzTBWoXq9SLNFiAZbW4NXjXpParXi7RCe2rwqdbeJ6F6vYj00NwSTVtq75PQwUpE4hPgm3VzA3xbau+TGFSvV/lGJAxNsspUFAP+xobaI4jUZKw18QGqDs2swaumPLlyd0vV60UmMnZnyQAdJZuZwaumPLlhx5rVcysyVLGzZN58bHUVjh6d/d6tgzQzwGuCtT46uLjI2CbqLBngvdTMEk2g5vmtoCWXIkNN1FkywNLuZmbwMjsq4YhcZqLOktqTdQiVC8Ir/g9A/w9ptVl3lpzpnqxmtmFmf2BmXzaz3zKzl1XZ3lDaezW8YglH5RtpubE6Szawm+TDwF3u/ryZ/TxwF/Avqw+rD02uxkXlG5HRNW1HJ3f/TOHi54EfqzacIbT3aly0AkdkdAES1DonWe8E/l2/G83sBHAC4MiRIzU+rESjHPCV0YvsCZCgDg3wZnYaeGWPm+52909073M38DzwYL/tuPsp4BRkk6wTjVaaRRm9SFBDA7y7Hx10u5ndAdwKvNVDLMmReA3L6BXwpU2a1k3SzG4BTgJ/z92fqWdIkixNykqCRl4q2bRJVuCXgb8IPGxmAJ93939aeVSSJk3KSmK2trKdnQ4ezPZsve8+WF7uc2ft6CStlrcyXl9XRi/R63SyLpLFfjRzc3D+/PR2emrPIfskPeVWxqCdpyRaecOxorzhWE8BXssK8BKPXk3ktLesRGrshmNqNiZSoolZidTYDcdUgxcZojwRq4lZCWyWDcdUg5e0DetXLzJjIzcca2CzMZGwtNRSmqKB6+BFwlL/G2mKhjcbEwmv15tIWb3EIECzMdXgJS2jLLUUmbVAy3uVwUv6VKeX0AKVDhXgJX2q08uUjLxEMtDR6FSikfYpt0TQ3rEyga2trBfNsWPZ6dbWgDv3Kh3OgAK8tI/W0ktFnU62B+vuLjz9dHa6uppd31OgJEIBXkRNzmRMYzcaC5REqAYv0mv5mur0MsDYjcZUgxeJiOr0MkDeaGxuDg4dyk4HNhpTDV4kIqrTyxDLy9nBPU6fzk77HskpYHKgEo3I
KLSHrPSwsDBCk7GA5T4FeJFRqE4vkwpUfwcFeJHJKauXYQK/HlSDF5mU+t60UqcDZ84MWPNeFPj1oAxepE7qe5O0ra1sh6aDB7NlkvfdN2ByFYKWZ0CH7BOZro2NLINbX1edvuE6nawlwe7u3nVzc9kKmmkfqi+nQ/aJxER7ySZj7L1XI/g/1xLgzezdZuZmpu+gIkWq0ydj7L1XI/g/V67Bm9m1wPcD36g+HJEW0OqbRsr3Xl1dzTL3S5eG7L0auP4O9Uyy/iJwEvhEDdsSSZ/W1DfW8jIcPTpCD/hIPrArBXgzuw140t2/ZGbD7nsCOAFw5MiRKg8rkp4RVt+MfHAJqVX5eY9979WioQHezE4Dr+xx093AT5OVZ4Zy91PAKchW0YwxRpH0DTnq1NjL86QWEz/vEZRnoMIySTP7m8BngWe6Vx0G/hh4vbt/e9DvapmkyBCFDL7j81x3HVy1u8MKm2yywjNz8zNdntdGEy2LnHJpZmbLJN39K+7+cndfdPdF4AJw47DgLiIjKKy+yZfnrbDJBidZYXPw8jypxdjLIiGKlTNF2pNVJHL58rxNsq/7m6zsLc+LZDIvRWMvi4RoSjO52nZ06mby2nNDpGb58rxn5uY5dWiNZ+bm95bnlTPGCHauabJin5mxD+oR4YetMniRBui7PK+cMfZYvaHVN6PpN6E60rJIiGblTJF60YikpJRFbm3ByTt3uINNHmCF9fvntfqmh1r6zMwgg1cvGpE2K0zOdjpZRnr7s5v87LMnuf3ZTVZXR2xz2zITTajm8rIYBDnu6iAq0YgkKg9am7t7k7N50Fqw+OrFs1QuW000oZqLsDSTU4AXSVQetJ5mng+QBZ65PGhFHJSmrV+tfaw+M7BXkjl+PLscycqZIgV4kUQNbI7V0tYIedlqd3ev3r66mk2kjjWhCo34kFSAF0lY36A1QmuEFCdn87JVcTL1hbLVqH1mGpC55xTgRRI3UtAqZPR5lvv2Zzf5WU7yLLC6usbRo83J5Pt9+6hUa881IHPPKcCLyL6M/tyZ3pOzFx7dYeHR4ROzoUs7gxqEjd3TvahBmfsL3H3mPzfddJOLSJyeesp9bs4d9n7m5ty/87717ML6+t6dO53scqfj7u4f+Uh235e+NDv9yEemM74vfjE7HXXs5fsO2kZf6z3+/hkDtn2MWKt18CKyT79d9K9+Z4/jyxZaJXQ6Wd3+7bsbXPn0Dru7jLTuvtgeYJitrWyHpGPHstOtrf23j7qefWEBbr55jMx9YyPL3Mt/f+zG+TSo60cZvEj8RspyCxn8F7/o/q9elGW572Hdwf3QoWwb/YyT8Y+SnY+awY8lgsw9x5gZvGrwItLTSJOzhdr9osMDrPAse50vDz23w9/47U24/vK6/aAli70ed9gKmHzME9fYi4rLRiPrEDkOlWhEpBYLC7B+/zwfmlvj0qF55ubgkz+yydU/07vb5blz8Mord3gPG/xlsu6Xg9oDjLoCZnk56yFz+nR2OvLyzmInzmKXzkL7h6ZRBi8itbls3b2twA307Ha5eMcat+9u8nNklz/A2sCMf5zsfKRvH7k8oF+8CPfem13X4Kx9n3HqOXX9qAYv0lKlVTcP/WrH77py3Rev7vjcnPsjy6V6d+n+3un4d9637o98ppPV1cu393usfufd92rs99zTf1uRQDV4EYlWaQ/aH/0n87zpH6zxI+eGZ/ysrcFmVvK5YR04tgYbpduLtfPi70Lv82tr+7P1BpZhBlGAF5Gg9pdTSi0UyqWSYafFoN6rzNLrfLltQ0J0wA8RSUeEh82r07gH/FAGLyLpSDgbn4SWSYqIJEoBXkQkUQrwIiKJqhzgzeydZvYHZvZVM1uvY1AiIlJdpUlWM/s+4Dbgde7+52b28nqGJSIiVVXN4H8SeL+7/zmAuz9VfUgiIlKHqgH+e4C/a2ZfMLP/bGY31zEoERGpbmiJxsxOA6/scdPd3d//LuANwM3Ab5rZX/Mee0+Z
2QngRPfi/zGzJyYedb3modvKLn4aa/2aMk7QWKehKeOEbKzXjfMLlfZkNbNPAz/v7r/bvfw/gDe4+wjHZomDmW2Ps2dYSBpr/ZoyTtBYp6Ep44TJxlq1RPPvge/rPvj3AAdpzqehiEjSqrYquB+438x+H3gOeFuv8oyIiMxepQDv7s8BP17TWEI5FXoAY9BY69eUcYLGOg1NGSdMMNYg3SRFRGT61KpARCRRCvAiIolqdYA3s1vM7Akz+5qZvTf0ePoxs2vN7HfN7LFuz593hR7TIGZ2hZn9NzP77dBjGcTMXmZmD3V7KT1uZn879Jj6MbN/3v3f/76ZbZnZi0KPCcDM7jezp7oLLfLrvsvMHjazP+ye/qWQY8z1GetG9///ZTP7LTN7Wcgx5nqNtXDbu83MzWzoEU1aG+DN7ArgQ8DfB14DLJvZa8KOqq/ngXe7+2vIdip7e8RjBXgX8HjoQYzgg8Cn3f2vA68j0jGb2TXAPwOW3P21wBXA7WFH9YIHgFtK170X+Ky7vwr4bPdyDB7g8rE+DLzW3b8X+O/AXbMeVB8PcPlYMbNrge8HvjHKRlob4IHXA19z9693VwN9lKxxWnTc/Vvu/kj3/HfIAtE1YUfVm5kdBn4Q+HDosQxiZi8F3gTcB9mKMHf/32FHNdCVwJyZXQlcBfxx4PEA4O7/BfifpatvA369e/7XgR+e6aD66DVWd/+Muz/fvfh54PDMB9ZDn+cV4BeBk8BIq2PaHOCvAb5ZuHyBSINmkZktkh13/gthR9LXL5G9AP9f6IEMcT3QATa75aQPm9mLQw+qF3d/EvgAWdb2LeBpd/9M2FEN9Ap3/1b3/LeBV4QczBjuBP5j6EH0Y2a3AU+6+5dG/Z02B/jGMbOrgY8BP+XufxZ6PGVmdivwlLufDT2WEVwJ3Aj8irvfAFwknlLCPt0a9m1kH0p/FXixmTVi/5Pujo/Rr8U2s7vJSqEPhh5LL2Z2FfDTwPvG+b02B/gngWsLlw93r4uSmR0gC+4PuvvHQ4+njzcCx83sHFnJ6y1m9hthh9TXBeCCu+ffhB4iC/gxOgr8kbt33P0S8HHg7wQe0yB/YmZ/BaB7GnUbcTO7A7gV+EcR74n/3WQf8F/qvr8OA4+YWa9GkC9oc4A/A7zKzK43s4Nkk1afDDymnszMyGrFj7v7L4QeTz/ufpe7H3b3RbLn8z+5e5SZprt/G/immb26e9VbgccCDmmQbwBvMLOruq+FtxLphHDXJ4G3dc+/DfhEwLEMZGa3kJUUj7v7M6HH04+7f8XdX+7ui9331wXgxu7ruK/WBvjuxMo7gN8he7P8prt/Neyo+noj8BNkGfGj3Z8fCD2oBLwTeNDMvgz8LeDnAo+np+63jIeAR4CvkL1vo9jF3sy2gN8DXm1mF8xsFXg/cMzM/pDs28f7Q44x12esvwy8BHi4+7761aCD7Ooz1vG3E+83EhERqaK1GbyISOoU4EVEEqUALyKSKAV4EZFEKcCLiCRKAV5EJFEK8CIiifr/v+HeTVk7zTkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Trajectory actually driven by real_car_model (blue) over the reference spline (red);\n",
    "# the per-step states are concatenated and regrouped into rows of 6 values.\n",
    "real_path_ = torch.cat(real_path, dim=0).reshape(-1, 6)\n",
    "print(real_path_.size())\n",
    "\n",
    "x_ = real_path_[:, 0].detach().numpy()\n",
    "y_ = real_path_[:, 1].detach().numpy()\n",
    "\n",
    "plt.scatter(x_, y_, c='b', s=20)\n",
    "plt.scatter(cx, cy, c='r', s=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last MPC predicted rollout (blue) over the reference spline (red).\n",
    "x_, y_ = path[:, 0].detach().numpy(), path[:, 1].detach().numpy()\n",
    "\n",
    "plt.scatter(x_, y_, c='b', s=20)\n",
    "plt.scatter(cx, cy, c='r', s=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Previously computed x_, y_ (blue), the reference spline (red), and the spline\n",
    "# points at offsets `vs` relative to `start` (black).\n",
    "plt.scatter(x_, y_, c='b', s=40)\n",
    "plt.scatter(cx, cy, c='r', s=1)\n",
    "wx, wy = cx[vs+start].data.cpu().numpy(), cy[vs+start].data.cpu().numpy()\n",
    "# Removed `wx,wy=s0[:,0],s0[:,1]`: `s0` is never defined (NameError) and it discarded\n",
    "# the waypoints computed on the line above.\n",
    "plt.scatter(wx, wy, c='k', s=40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "ql = 1.0\n",
    "qc = 1.0\n",
    "\n",
    "\n",
    "def evaluate_(path, refs, vs):\n",
    "    \"\"\"Loss that ``MPC.run`` minimizes for a predicted rollout ``path`` against ``refs``.\n",
    "\n",
    "    Thin wrapper around ``error_deviation_parallel_`` (not defined in this notebook;\n",
    "    presumably provided by one of the star imports). It passes the module-level ``ql``\n",
    "    and ``qc``, which are looked up at call time. ``vs`` is accepted to match the\n",
    "    ``self.evaluate(path, refs, vs)`` call in ``MPC.run`` but is unused.\n",
    "    \"\"\"\n",
    "    return error_deviation_parallel_(path, refs, ql, qc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate=0.01\n",
    "\n",
    "\n",
    "class MPC:\n",
    "    \"\"\"Gradient-based MPC: optimizes a control sequence with Adam through rollouts of ``model``.\n",
    "\n",
    "    Args:\n",
    "        model: callable ``model(state, control_dt) -> next_state``.\n",
    "        evaluate: callable ``evaluate(path, refs, vs)`` returning a scalar loss tensor.\n",
    "        len_horizon: number of rollout steps per optimization pass.\n",
    "        waypoints: reference table; rows ``vs[1:] + start`` become the references.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,model,evaluate,len_horizon,waypoints):\n",
    "        self.waypoints = waypoints\n",
    "        self.len_horizon = len_horizon\n",
    "        self.model = model\n",
    "        self.evaluate = evaluate\n",
    "\n",
    "    def _rollout(self, state_init, controls_dt):\n",
    "        \"\"\"Apply ``self.model`` for ``self.len_horizon`` steps; return the states stacked row-wise.\"\"\"\n",
    "        state = state_init\n",
    "        path = []\n",
    "        for t in range(self.len_horizon):\n",
    "            state = self.model(state, controls_dt[t])\n",
    "            path.append(state.view(1, -1))\n",
    "        return torch.cat(path, dim=0)\n",
    "\n",
    "    def run(self,state_init,controls_init,vs,dt,start,emax=None,lr=None):\n",
    "        \"\"\"Optimize the controls for one MPC step.\n",
    "\n",
    "        Args:\n",
    "            state_init: initial state (copied; gradients do not flow into it).\n",
    "            controls_init: initial guess for the controls (copied), one row per step.\n",
    "            vs: integer offsets; ``vs[1:] + start`` index ``self.waypoints`` for the references.\n",
    "            dt: per-step time-step column; if it has length 1 it is broadcast to\n",
    "                ``self.len_horizon`` rows.\n",
    "            start: waypoint index the offsets are relative to.\n",
    "            emax: number of Adam iterations (default: global ``EMAX`` at call time).\n",
    "            lr: Adam learning rate (default: global ``learning_rate`` at call time).\n",
    "\n",
    "        Returns:\n",
    "            ``(controls_dt, path)``: optimized controls with ``dt`` appended as the last\n",
    "            column, and the rollout produced by exactly those controls.\n",
    "        \"\"\"\n",
    "        if emax is None:\n",
    "            emax = EMAX\n",
    "        if lr is None:\n",
    "            lr = learning_rate\n",
    "        start = torch.LongTensor([start])\n",
    "        state_init = state_init.data.clone()\n",
    "        controls = torch.nn.Parameter(controls_init.data.clone())\n",
    "        vs = vs.detach().clone().long()\n",
    "        if len(dt)==1:\n",
    "            dt = torch.ones(self.len_horizon).view(-1,1)*dt\n",
    "        opt = torch.optim.Adam([controls],lr=lr)\n",
    "        # Use this instance's waypoints rather than a same-named global.\n",
    "        refs = self.waypoints[vs[1:]+start]\n",
    "        for epoch in range(emax):\n",
    "            controls_dt = torch.cat([controls,dt],dim=1)\n",
    "            path = self._rollout(state_init, controls_dt)\n",
    "            opt.zero_grad()\n",
    "            loss = self.evaluate(path,refs,vs)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "            if epoch % 500 == 0:\n",
    "                print(loss.data.numpy())\n",
    "        # Re-roll with the final controls so the returned path matches them (the last\n",
    "        # in-loop rollout predates the final opt.step()); this also defines path when emax == 0.\n",
    "        controls_dt = torch.cat([controls,dt],dim=1)\n",
    "        path = self._rollout(state_init, controls_dt)\n",
    "        return controls_dt,path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
